{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "FH3IRpYTta2v"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FH3IRpYTta2v"
      },
      "source": [
        "##### Copyright 2023 The IREE Authors"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWGa71_Ct2ug",
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h5s6ncerSpc5"
      },
      "source": [
        "# Dynamic Shapes\n",
        "\n",
        "This notebook\n",
        "\n",
        "1. Creates a PyTorch program with dynamic shapes using [iree-turbine](https://github.com/iree-org/iree-turbine)'s advanced AOT toolkit\n",
        "2. Compiles the program to an IREE VM bytecode module\n",
        "3. Tests running the compiled VM module using IREE's runtime\n",
        "4. Downloads compilation artifacts for use with the native (C API) sample application"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s2bScbYkP6VZ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "fede6b57-a87b-42fc-ab71-57c1b8ff4ab3"
      },
      "source": [
        "#@title General setup\n",
        "\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
        "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using artifacts directory '/tmp/iree/colab_artifacts'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "#@title Uninstall existing packages\n",
        "#   This avoids some warnings when installing specific PyTorch packages below.\n",
        "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
      ],
      "metadata": {
        "id": "y9KOsqosg6Ms"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Install iree-turbine\n",
        "\n",
        "# Limit cell height.\n",
        "from IPython.display import Javascript\n",
        "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n",
        "\n",
        "!python -m pip install iree-turbine"
      ],
      "metadata": {
        "id": "SdCAvI3sqBO7",
        "outputId": "2d38c722-33cf-4210-89a7-bf4f42f92ab9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 300
        }
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting iree-turbine\n",
            "  Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n",
            "Collecting iree-base-compiler (from iree-turbine)\n",
            "  Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
            "Collecting iree-base-runtime (from iree-turbine)\n",
            "  Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
            "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n",
            "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n",
            "  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
            "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n",
            "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n",
            "  Attempting uninstall: ml_dtypes\n",
            "    Found existing installation: ml-dtypes 0.4.1\n",
            "    Uninstalling ml-dtypes-0.4.1:\n",
            "      Successfully uninstalled ml-dtypes-0.4.1\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Report version information\n",
        "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
        "\n",
        "!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
        "!iree-compile --version\n",
        "\n",
        "import torch\n",
        "print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Oj5I6R9LI7t_",
        "outputId": "deaa1abf-dc0e-49d8-d165-47d53592d94f"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Installed iree-turbine, Version: 3.1.0\n",
            "\n",
            "Installed IREE, compiler version information:\n",
            "IREE (https://iree.dev):\n",
            "  IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n",
            "  LLVM version 20.0.0git\n",
            "  Optimized build\n",
            "\n",
            "Installed PyTorch, version: 2.5.1+cu121\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Create a program using PyTorch + iree-turbine\n",
        "\n",
        "NOTE: as in other domains, providing more information to a compiler allows it\n",
        "to generate more efficient code. As a general rule, the slowest varying\n",
        "dimensions of program data like batch index or timestep are safer to treat as\n",
        "dynamic than faster varying dimensions like image x/y/channel. See\n",
        "[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the\n",
        "challenges imposed by dynamic shapes and one project's approach to addressing\n",
        "them."
      ],
      "metadata": {
        "id": "C3mhaullI940"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Define a sample `torch.nn.Module`.\n",
        "\n",
        "import iree.turbine.aot as aot\n",
        "\n",
        "class DynamicShapesModule(torch.nn.Module):\n",
        "  # reduce_sum_1d (dynamic input size, static output size)\n",
        "  #   tensor<?xi32> -> tensor<i32>\n",
        "  #   e.g. [1, 2, 3] -> 6\n",
        "  def reduce_sum_1d(self, values):\n",
        "    return torch.sum(values)\n",
        "\n",
        "  # reduce_sum_2d (partially dynamic input size, static output size)\n",
        "  #   tensor<?x3xi32> -> tensor<3xi32>\n",
        "  #   e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n",
        "  def reduce_sum_2d(self, values):\n",
        "    return torch.sum(values, 0)\n",
        "\n",
        "  # add_one (dynamic input size, dynamic output size)\n",
        "  #   tensor<?xi32>) -> tensor<?xi32>\n",
        "  #   e.g. [1, 2, 3] -> [2, 3, 4]\n",
        "  def add_one(self, values):\n",
        "    return values + 1"
      ],
      "metadata": {
        "id": "vsf9F4WxI_DX"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Export using FxProgramsBuilder.\n",
        "\n",
        "fxb = aot.FxProgramsBuilder(DynamicShapesModule())\n",
        "\n",
        "# Create a single dynamic export dimension.\n",
        "dynamic_x = torch.export.Dim(\"x\")\n",
        "# Example inputs with a mix of placeholder (dynamic) and static dimensions.\n",
        "example_1d = torch.empty(16, dtype=torch.int32)\n",
        "example_2d = torch.empty((16, 3), dtype=torch.int32)\n",
        "\n",
        "# Export reduce_sum_1d with a dynamic dimension.\n",
        "@fxb.export_program(\n",
        "    args=(example_1d,),\n",
        "    dynamic_shapes={\"values\": {0: dynamic_x}},\n",
        ")\n",
        "def reduce_sum_1d(module, values):\n",
        "    return module.reduce_sum_1d(values)\n",
        "\n",
        "# Export reduce_sum_2d with one dynamic dimension.\n",
        "@fxb.export_program(\n",
        "    args=(example_2d,),\n",
        "    dynamic_shapes={\"values\": {0: dynamic_x}},\n",
        ")\n",
        "def reduce_sum_2d(module, values):\n",
        "    return module.reduce_sum_2d(values)\n",
        "\n",
        "# Export add_one with a dynamic dimension.\n",
        "@fxb.export_program(\n",
        "    args=(example_1d,),\n",
        "    dynamic_shapes={\"values\": {0: dynamic_x}},\n",
        ")\n",
        "def add_one(module, values):\n",
        "    return module.add_one(values)\n",
        "\n",
        "export_output = aot.export(fxb)"
      ],
      "metadata": {
        "id": "cCy3nuLBKTAg"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from iree.compiler.ir import Context\n",
        "\n",
        "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n",
        "export_output.save_mlir(imported_mlir_path)\n",
        "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")\n",
        "\n",
        "# Inspect the IR.\n",
        "# Note the question marks for dynamic shapes in types, like `tensor<?xi32>`.\n",
        "print(\"\\nDynamic Shapes MLIR:\")\n",
        "!cat {imported_mlir_path}"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_OQIpOtNr4Gh",
        "outputId": "abe96b74-88de-4979-959c-cdfbc981b17c"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n",
            "\n",
            "Dynamic Shapes MLIR:\n",
            "module @module {\n",
            "  func.func @reduce_sum_1d(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[],si64> attributes {torch.assume_strict_symbolic_shapes} {\n",
            "    %none = torch.constant.none\n",
            "    %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?],si32>, !torch.none -> !torch.vtensor<[],si64>\n",
            "    return %0 : !torch.vtensor<[],si64>\n",
            "  }\n",
            "  func.func @reduce_sum_2d(%arg0: !torch.vtensor<[?,3],si32>) -> !torch.vtensor<[3],si64> attributes {torch.assume_strict_symbolic_shapes} {\n",
            "    %int0 = torch.constant.int 0\n",
            "    %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n",
            "    %false = torch.constant.bool false\n",
            "    %none = torch.constant.none\n",
            "    %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,3],si32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>\n",
            "    return %1 : !torch.vtensor<[3],si64>\n",
            "  }\n",
            "  func.func @add_one(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si32> attributes {torch.assume_strict_symbolic_shapes} {\n",
            "    %int1 = torch.constant.int 1\n",
            "    %int1_0 = torch.constant.int 1\n",
            "    %0 = torch.aten.add.Scalar %arg0, %int1, %int1_0 : !torch.vtensor<[?],si32>, !torch.int, !torch.int -> !torch.vtensor<[?],si32>\n",
            "    return %0 : !torch.vtensor<[?],si32>\n",
            "  }\n",
            "}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Test the imported program\n",
        "\n",
        "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n",
        "\n",
        "_See the [README](https://github.com/iree-org/iree/tree/main/samples/dynamic_shapes#instructions) for more details and example command line instructions._\n",
        "\n",
        "* _The \"imported MLIR\" (above) can be used by IREE's generic compiler tools_\n",
        "* _The \"binary\" can be saved and used by runtime applications_\n",
        "\n",
        "_The specific point at which you switch from Python to native tools will depend on your project._"
      ],
      "metadata": {
        "id": "z6w_Pbl6tUtJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Compile to a file on disk for usage outside of Python.\n",
        "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n",
        "export_output.compile(save_to=flatbuffer_path)\n",
        "print(f\"Wrote compiled program to path '{flatbuffer_path}'\")\n",
        "\n",
        "# Compile into memory for testing.\n",
        "binary = export_output.compile(save_to=None)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0PGyH1tvI_Ic",
        "outputId": "2ac3f280-1834-4d6c-f5b0-c9b470549ca7"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wrote compiled program to path '/tmp/iree/colab_artifacts/dynamic_shapes_cpu.vmfb'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import iree.runtime as ireert\n",
        "import numpy as np\n",
        "\n",
        "# Use the IREE runtime API to test the compiled program.\n",
        "config = ireert.Config(\"local-task\")\n",
        "vm_module = ireert.load_vm_module(\n",
        "    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n",
        "    config,\n",
        ")\n",
        "\n",
        "print(vm_module.reduce_sum_1d(np.array([1, 10, 100], dtype=np.int32)).to_host())\n",
        "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30]], dtype=np.int32)).to_host())\n",
        "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30], [100, 200, 300]], dtype=np.int32)).to_host())\n",
        "print(vm_module.add_one(np.array([1, 10, 100], dtype=np.int32)).to_host())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9ilJY15BI_LD",
        "outputId": "f20aec4f-353e-4793-f9f1-066006d4471b"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "111\n",
            "[11 22 33]\n",
            "[111 222 333]\n",
            "[  2  11 101]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download compilation artifacts"
      ],
      "metadata": {
        "id": "3mizlpY9uJEW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ARTIFACTS_ZIP = \"/tmp/dynamic_shapes_colab_artifacts.zip\"\n",
        "\n",
        "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n",
        "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n",
        "\n",
        "# Note: you can also download files using Colab's file explorer\n",
        "try:\n",
        "  from google.colab import files\n",
        "  print(\"Downloading the artifacts zip file...\")\n",
        "  files.download(ARTIFACTS_ZIP)\n",
        "except ImportError:\n",
        "  print(\"Missing google_colab Python package, can't download files\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "id": "dgaXpdiWuGtx",
        "outputId": "94823b69-1095-4a97-9974-7d36fb3e2fb8"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n",
            "  adding: dynamic_shapes_cpu.vmfb (deflated 66%)\n",
            "  adding: dynamic_shapes.mlir (deflated 72%)\n",
            "Downloading the artifacts zip file...\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_7377c999-5cd8-4987-95c4-921d56969f65\", \"dynamic_shapes_colab_artifacts.zip\", 5472)"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}